feat(aggregation): Add MoDoWeighting #717
PierreQuinton
left a comment
Looking good. It still needs a few changes, but once this is merged, I think it will make #676 easier to merge.
| with torch.no_grad(): | ||
| grad = gramian @ lambd + self._rho * lambd | ||
| lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) |
So in the end, this is a softmax. @rkhosrowshahi I think this means that MoCo is essentially just a composition with this weighting: you give it yy_t, and then multiply yy_t by the obtained weights. Is that correct? If yes, I think we should change #676 accordingly.
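For reference, the update quoted in the diff above can be sketched without any dependency. This is only a minimal sketch using plain Python lists instead of tensors; `softmax` and `modo_softmax_step` are hypothetical helper names, and the 2x2 matrix is made up for illustration.

```python
import math


def softmax(values):
    """Numerically stable softmax over a list of floats."""
    shift = max(values)
    exps = [math.exp(v - shift) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


def modo_softmax_step(matrix, lambd, gamma=0.1, rho=0.0):
    """One update: softmax(lambd - gamma * (matrix @ lambd + rho * lambd))."""
    grad = [
        sum(row[j] * lambd[j] for j in range(len(lambd))) + rho * lambd[i]
        for i, row in enumerate(matrix)
    ]
    return softmax([l - gamma * g for l, g in zip(lambd, grad)])


# Illustrative matrix only, not real Jacobian data.
matrix = [[2.0, 0.5], [0.5, 1.0]]
lambd = modo_softmax_step(matrix, [0.5, 0.5])
```

Starting from uniform weights, the second entry ends up larger because its row of the matrix produces the smaller gradient, and the result still sums to 1.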
PierreQuinton
left a comment
For me this is ready; let's still wait for @ValerianRey's review.
ValerianRey
left a comment
I don't think this is equivalent to the paper or the official or libmtl implementation. In all of these, the gramian is computed as J_1 @ J_2^T (I think this aggregator makes 0 sense for IWRM because of that, so we gotta think of it in MTL context). Autojac's gramian computed on losses_1 would be J_1 @ J_1^T though, so I think this PR's usage example is wrong.
See equation 2.9a, step 3 of the algorithm, or line 55 of the libmtl implementation.
I think the only way to add MoDo to torchjd would be with the same implementation but different usage example:
- user computes J_1 and J_2 using autojac.jac
- user computes G = J_1 @ J_2^T
- user computes weights by applying a MoDoWeighting to G
- user does an extra backward pass with some new losses (losses_3) weighted with the obtained weights.
I think depending on the implementation they either use only losses_1 and losses_2, or they also use losses_3. Idk what's best.
Note that here G is not a gramian, and is not PSD in general. Gotta type MoDoWeighting properly.
What do you think @PierreQuinton @KhusPatel4450 ?
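To make the "not a gramian, not PSD in general" point concrete, here is a small dependency-free sketch. The two Jacobians are hypothetical hand-picked values (2 losses, 2 parameters), and `matmul_transpose`, `is_symmetric` and `quadratic_form` are helper names introduced only for this illustration.

```python
def matmul_transpose(a, b):
    """Return a @ b^T for matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row_a, row_b)) for row_b in b] for row_a in a]


def is_symmetric(m):
    """Check m == m^T entry by entry."""
    n = len(m)
    return all(m[i][j] == m[j][i] for i in range(n) for j in range(n))


def quadratic_form(m, v):
    """Return v^T m v."""
    n = len(v)
    return sum(v[i] * m[i][j] * v[j] for i in range(n) for j in range(n))


# Hypothetical Jacobians, one per mini-batch.
J_1 = [[1.0, 0.0], [0.0, 1.0]]
J_2 = [[-1.0, 1.0], [0.0, -1.0]]

gramian = matmul_transpose(J_1, J_1)  # J_1 @ J_1^T: symmetric and PSD
cross = matmul_transpose(J_1, J_2)    # J_1 @ J_2^T: neither, in this example
```

With these values, `cross` is not symmetric and has a negative quadratic form for `v = [1.0, 0.0]`, so a weighting typed for PSD matrices would be receiving something outside its contract.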
| - Added `MoDoWeighting` from [Three-Way Trade-Off in Multi-Objective Learning: Optimization,Generalization and Conflict-Avoidance](https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf) (JMLR 2024). It is a stateful `Weighting` that maintains task weights across calls via a | ||
| softmax-projected gradient step on the Gramian, intended to be composed with `autogram.Engine` | ||
| in a two-batch training loop. |
Can remove the description of what MoDo is, or fix it (it's not doable with autogram with the changes I suggest).
| from ._weighting_bases import _GramianWeighting | ||
| class MoDoWeighting(_GramianWeighting, Stateful, _NonDifferentiable): |
Can't inherit from _GramianWeighting anymore with the changes I propose. Need to either inherit from _MatrixWeighting, or inherit from Weighting[Matrix] and override call to make a more specific docstring (e.g. explaining exactly what the matrix is: J_1 @ J_2^T)
Also need to fix the main docstring accordingly.
| <https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf>`_ (JMLR 2024), commonly referred | ||
| to as MoDo (Multi-Objective gradient with Double sampling). |
Arguably the whole method (given in the usage example) is called Modo. Can remove this last part. The part explaining the acronym (in parentheses) is nice though.
| <https://www.jmlr.org/papers/volume25/23-1287/23-1287.pdf>`_ (JMLR 2024), commonly referred | ||
| to as MoDo (Multi-Objective gradient with Double sampling). | ||
| Given a Gramian :math:`G`, the weights :math:`\lambda` are updated at each call by a |
Not a gramian anymore with the changes I propose.
| .. warning:: | ||
| MoDo's convergence guarantees rely on **double sampling**: the Gramian passed to this | ||
| weighting must come from a mini-batch that is independent of the one used for the | ||
| subsequent parameter update. The Gramian can be computed efficiently from a batch of | ||
| losses using the :class:`~torchjd.autogram.Engine`. See the usage example below. |
This is a bit wrong, we can't talk about a Gramian there and it should itself come from 2 batches. The last part about autogram should be removed.
| Train a model using MoDo with two independent mini-batches per step. The first batch | ||
| drives the :math:`\lambda` update via the Gramian; the second batch drives the parameter | ||
| update via the usual backward pass. |
I would simplify this explanation to make it clear that this is just doing basic MoDo, so that people actually follow this usage example if they wanna reproduce MoDo.
Something like:
Train a model with MoDo.
The role of their paper is to explain what MoDo is. The role of TorchJD is to make it extremely easy to reproduce MoDo, but not to explain it IMO (especially since it's quite complex).
Hello, I think you are correct. Going back to the paper and looking at equations 2.8 and 2.9, you are right. We originally discussed using autogram to compute the Gramian efficiently; at the time it seemed like the right fit since autogram gives The matrix in the λ update is So the changes needed are:
One open question on the model update (equation 2.9b): it uses a fresh |
Thanks for the quick reply. I agree. Please go ahead.
Idk, gotta investigate that. I'll look at that soon-ish.
For now you could just add the two usage examples. Later, we can decide to keep only one, or to change the description to say something like: "this is the default behavior of MoDo in the official implementation and in LibMTL" and "this is an alternative implementation that has the advantage of [...]"
My main concern is now fixed, tyvm @KhusPatel4450. I still need to review more in-depth (since there's a lot of code and tests). But I'm pretty sure we'll be able to merge this very soon!
ValerianRey
left a comment
We're getting there! I just made a thorough review so I think we can merge whenever everything is fixed.
| The following example reproduces basic MoDo using two independent mini-batches per step. | ||
| .. code-block:: python |
We need to use the testcode directive so that this usage example is tested by doctest (in the CI, or manually using uv run make -C ./docs doctest), or use no directive at all and prefix each line with >>>.
Otherwise the example isn't tested by doctest and we have no guarantee that it works and that it keeps working in future updates.
Same comment for the other usage example, and even for any future example you may write.
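As a rough illustration of the `>>>` alternative, here is a sketch that uses the standard-library `doctest` module rather than the Sphinx setup of this repository. `uniform_weights` is a hypothetical function invented for the example; the point is only that a `>>>` line in a docstring gets executed and compared against the expected output written below it.

```python
import doctest


def uniform_weights(m):
    """Return m equal weights that sum to 1.

    >>> uniform_weights(4)
    [0.25, 0.25, 0.25, 0.25]
    """
    return [1.0 / m] * m


# Collect and run the ">>>" examples found in the docstring above.
finder = doctest.DocTestFinder()
runner = doctest.DocTestRunner(verbose=False)
for test in finder.find(uniform_weights, name="uniform_weights",
                        globs={"uniform_weights": uniform_weights}):
    runner.run(test)
results = runner.summarize(verbose=False)
```

If the function's behavior changes in a later update, the example fails instead of silently going stale, which is the guarantee a plain code-block directive does not give.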
Would it be possible to have that inside a skill related to documentation for agents?
Too simple for a skill but we should definitely have that explained in AGENTS.md and / or contributing.md, or even have a script checking that we don't use any code-block directive. Feel free to make a PR with what you think is best.
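A check script along those lines could be very small. The following is a hypothetical sketch, not an existing script in the repository: `find_code_block_lines` is an invented name, and the sample text stands in for a docstring or .rst file.

```python
def find_code_block_lines(text):
    """Return 1-based line numbers where a `.. code-block::` directive starts."""
    return [
        number
        for number, line in enumerate(text.splitlines(), start=1)
        if line.lstrip().startswith(".. code-block::")
    ]


# Stand-in for the contents of a documentation file.
sample = """Usage example.

.. code-block:: python

    weights = weighting(G)

.. testcode::

    weights = weighting(G)
"""

offending = find_code_block_lines(sample)
```

A CI step could run this over the source and docs folders and fail when the returned list is non-empty.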
| params = list(model.parameters()) | ||
| # loader_1 and loader_2 must yield independent draws of the same size. | ||
| for batch_1, batch_2 in zip(loader_1, loader_2): |
We need the example to be self-sufficient, so we either need to define loader_1 and loader_2 (which is tedious) or just do something like:
inputs = ...
targets = ...
for i in range(len(inputs) // 2):
    input_1, input_2 = inputs[2*i], inputs[2*i + 1]
    target_1, target_2 = targets[2*i], targets[2*i + 1]
    ...
With the ... filled with the appropriate code.
Similar comment for the other example.
To better reflect the new usage, I think we should use G = J1 @ J2.T instead of G = J @ J.T in test_reset_restores_first_step_behavior, test_output_lies_on_simplex, test_update_recurrence and test_changing_m_auto_resets.
Similarly, we should use G1 = J1 @ J2.T and G2 = J3 @ J4.T in test_two_consecutive_steps.
| def test_small_gamma_stays_near_uniform() -> None: | ||
| """With a tiny gamma, one step barely moves lambda from the uniform initialisation.""" | ||
| J = randn_((3, 8)) | ||
| G = J @ J.T | ||
| m = J.shape[0] | ||
| W = MoDoWeighting(gamma=1e-8) | ||
| uniform = tensor_([1.0 / m] * m) | ||
| assert_close(W(G), uniform, atol=1e-6, rtol=1e-6) | ||
I think we can remove this test.
| def test_small_gamma_stays_near_uniform() -> None: | |
| """With a tiny gamma, one step barely moves lambda from the uniform initialisation.""" | |
| J = randn_((3, 8)) | |
| G = J @ J.T | |
| m = J.shape[0] | |
| W = MoDoWeighting(gamma=1e-8) | |
| uniform = tensor_([1.0 / m] * m) | |
| assert_close(W(G), uniform, atol=1e-6, rtol=1e-6) |
| optimizer.zero_grad() | ||
| """ | ||
| def __init__(self, gamma: float = 0.1, rho: float = 0.0) -> None: |
In the official implementation and in LibMTL, the default value of rho is 0.1. It would be better to match this IMO.
| G = J_1 @ J_2.T | ||
| weights = weighting(G) | ||
| losses_2.backward(weights) |
If we follow Equation 2.9b from the paper, this should be
losses = ((losses_1 + losses_2) / 2.0)
losses.backward(weights)
In the official implementation, it's also what they do, except that they forgot the division by 2.
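To spell out what the missing division changes, here is the identity in my own notation (an assumption, not copied from the paper): the weights are the fixed values returned by the weighting, the two loss vectors come from the two batches, and gradients are taken with respect to the model parameters.

```latex
\nabla_\theta \sum_{i=1}^{m} \lambda_i \left( \ell_{1,i}(\theta) + \ell_{2,i}(\theta) \right)
= 2 \, \nabla_\theta \sum_{i=1}^{m} \lambda_i \, \frac{\ell_{1,i}(\theta) + \ell_{2,i}(\theta)}{2}
```

So skipping the division only scales the update direction by 2. With plain SGD that should amount to doubling the learning rate; with other optimizers the effect may differ.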
| lambd = cast(Tensor, self._lambda) | ||
| grad = matrix @ lambd + self._rho * lambd | ||
| lambd = torch.softmax(lambd - self._gamma * grad, dim=-1) |
I think there was some confusion on discord when we talked about how to project onto the simplex. We all thought that the official implementation was using a softmax, but it (and LibMTL) actually uses:
def _projection2simplex(self, y):
    m = len(y)
    sorted_y = torch.sort(y, descending=True)[0]
    tmpsum = 0.0
    tmax_f = (torch.sum(y) - 1.0)/m
    for i in range(m-1):
        tmpsum+= sorted_y[i]
        tmax = (tmpsum - 1)/ (i+1.0)
        if tmax > sorted_y[i+1]:
            tmax_f = tmax
            break
    return torch.max(y - tmax_f, torch.zeros(m).to(y.device))
Should we use this way of projecting @PierreQuinton?
If we do that, we'll need to say that parts of this file were adapted from the official implementation, add a link to it, and add a notice in NOTICES @KhusPatel4450.
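To see how the two choices differ in practice, here is a dependency-free sketch. `project_to_simplex` is my transcription of the routine quoted above onto plain Python lists (an adaptation for illustration, not the code that would go in the PR), and `softmax` is a hypothetical helper for comparison.

```python
import math


def project_to_simplex(y):
    """Sort-based projection of y onto the probability simplex."""
    m = len(y)
    sorted_y = sorted(y, reverse=True)
    running_sum = 0.0
    threshold = (sum(y) - 1.0) / m
    for i in range(m - 1):
        running_sum += sorted_y[i]
        candidate = (running_sum - 1.0) / (i + 1.0)
        if candidate > sorted_y[i + 1]:
            threshold = candidate
            break
    # Shift by the threshold and clip negative entries to zero.
    return [max(v - threshold, 0.0) for v in y]


def softmax(values):
    """Numerically stable softmax, for comparison only."""
    shift = max(values)
    exps = [math.exp(v - shift) for v in values]
    total = sum(exps)
    return [e / total for e in exps]


projected = project_to_simplex([2.0, 0.0])
soft = softmax([2.0, 0.0])
balanced = project_to_simplex([0.6, 0.6])
```

The projection can return exact zeros (here the second weight is dropped entirely), while softmax always keeps every weight strictly positive, so the two variants are not interchangeable.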
I think I know what happened now: the code that I was told to read was from Rasa's MoCo.py, which used torch.softmax, but now I see that the official implementation uses this projection.
I personally think we should follow this.
| .. admonition:: Example (three batches per step) | ||
| The following example reproduces basic MoDo using three independent mini-batches per step, |
Maybe we could add that this is the behavior of MoDo in LibMTL and in the official implementation when three_grads is True.
| .. admonition:: Example (two batches per step) | ||
| The following example reproduces basic MoDo using two independent mini-batches per step. |
Maybe we could add that this is MoDo as described in the paper, and it's the behavior of the official implementation when three_grads is False.
This commit has everything except the projection-onto-simplex change. I will add a new commit for that once we reach a conclusion @ValerianRey
This commit now has the projection onto the simplex, as discussed on Discord.
ValerianRey
left a comment
LGTM, we can merge! Thanks a lot for the work @KhusPatel4450
@PierreQuinton feel free to merge or re-review if you prefer
Congrats on the merge @KhusPatel4450 and thanks again.
Adds MoDoWeighting from Three-Way Trade-Off in Multi-Objective Learning: Optimization, Generalization and Conflict-Avoidance (JMLR 2024).
It's a stateful Weighting[PSDMatrix] implementing the λ-update from Algorithm 2:
Per the discussion with @PierreQuinton and @ValerianRey on Discord, this follows the official LibMTL implementation which uses softmax rather than the paper's hard simplex projection.
Designed to be composed with autogram.Engine in a two-batch training loop so that MoDo's double-sampling property is preserved (Gramian comes from batch 1; backward uses batch 2).
Test plan
- tests/unit/aggregation/test_modo.py (12 functions, 72 cases: structural, reset, parameter validation, softmax boundary cases, recurrence verification)
- ty check passes on _modo.py
- -W --keep-going -n